Assignment 5 - Explainability in Machine Learning

Week 5 - PM034A Machine learning for socio-technical systems

By Nadia Metoui
Faculty of Technology, Policy, and Management (TPM)

Learning Objectives
Implement and examine intrinsically explainable ML models.
Implement and examine post-hoc explainability methods.
Discuss explainability in a multi-stakeholder socio-technical context.

Assignment Scenario
The healthcare sector is a high-risk and highly regulated sector. Many stakeholders are involved in ensuring these regulations are respected and risks are kept to a minimum. Machine learning has many advantages and can be valuable for many health applications. It should, however, be transparent, trustworthy, and trusted.

In this assignment, you will train a classifier to predict if a patient is at risk of diabetes or not. This model will be used by physicians before selecting candidates for a new drug trial. This drug might have adverse effects on individuals with risks of developing diabetes. Therefore, these individuals should not be on the list of candidates.

After training the model, you should ensure the model and its results are explainable to the various stakeholders (including the development team, the hospital management, the doctors, and the candidates/patients).

Concretely, you will practice the different types of explainability we saw in class (during the Lecture and Lab 5); you will observe the explanations produced by each type of explainer and discuss their utility (for different stakeholders) and their limitations.

Assignment Steps

This assignment is composed of four parts.

  • Part 0: Setting up the Lab (Not graded. Code Provided)
  • Part I: Intrinsic Explainability (20* points)
  • Part II: LIME: Local Interpretable Model-agnostic Explanations (40* points)
  • Part III: SHAP: SHapley Additive exPlanations (40* points)

Each part contains coding tasks (To Code) and textual answers (To Answer).

*The total number of points out of 100 will be normalized to calculate your grade out of 10 points

Submission Instructions

  • Answer the questions (code and text) in the notebook Assignment_5.ipynb
  • Add as many cells as needed
  • Rename the notebook (ipynb) by adding your name and surname (Assignment_5_<name_surname>, e.g. Assignment_5_nadia_metoui.ipynb)
  • Create an HTML version of your notebook (fully executed).
  • Submit your work as a zip file with the ipynb and HTML (fully executed) in Brightspace

Part 0: Setting up the Lab (Not graded, Code Provided)¶

Install and Load the libraries for the Lab.

In [78]:
!pip install lime
!pip install shap
In [79]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
from IPython.display import Markdown, display

import seaborn as sns
import sklearn.model_selection
import sklearn.metrics
import sklearn.datasets
import sklearn.ensemble
import sklearn.preprocessing
from sklearn.metrics import accuracy_score

import xgboost
from xgboost import plot_importance


import lime
import lime.lime_tabular

import shap
from shap.plots import _waterfall




np.random.seed(1)

Load and Prepare the Data.¶

For this assignment you will be using the Pima Indians Diabetes Database. We use the preprocessed version published on the Kaggle website (here).

Load the dataset

In [149]:
#Note: you may have to move the csv file to the appropriate folder or change the path "data/diabetes.csv" in the code below
#Load data from CSV file to a data frame
df_diabetes = pd.read_csv("data/diabetes.csv")
df_diabetes.shape
Out[149]:
(768, 9)

Check dataset¶

In [81]:
df_diabetes.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 768 entries, 0 to 767
Data columns (total 9 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   Pregnancies               768 non-null    int64  
 1   Glucose                   768 non-null    int64  
 2   BloodPressure             768 non-null    int64  
 3   SkinThickness             768 non-null    int64  
 4   Insulin                   768 non-null    int64  
 5   BMI                       768 non-null    float64
 6   DiabetesPedigreeFunction  768 non-null    float64
 7   Age                       768 non-null    int64  
 8   Outcome                   768 non-null    int64  
dtypes: float64(2), int64(7)
memory usage: 54.1 KB

Here we see that we have no null values. The dataset consists of 8 columns of medical/personal data that might indicate whether a person is at risk of diabetes, plus the Outcome column that records this risk. In total, we have data for 768 patients.

Explore Data¶

[Optional - Not Graded]
We recommend you do some data exploration to get familiar with the attributes and their values.
This part is not graded, but if you provide great data visualization and exploration ideas you might get a bonus :)

In [82]:
df_diabetes.head()
Out[82]:
Pregnancies Glucose BloodPressure SkinThickness Insulin BMI DiabetesPedigreeFunction Age Outcome
0 6 148 72 35 0 33.6 0.627 50 1
1 1 85 66 29 0 26.6 0.351 31 0
2 8 183 64 0 0 23.3 0.672 32 1
3 1 89 66 23 94 28.1 0.167 21 0
4 0 137 40 35 168 43.1 2.288 33 1
In [83]:
rows = 3
columns = 3
fig, axes = plt.subplots(rows, columns, figsize=(8, 8))
fig.set_tight_layout(True)
i, j = 0, 0
all_columns = df_diabetes.columns
for c in all_columns:
    if c == 'Outcome':
        # For the label itself, a boxplot against Outcome is not meaningful,
        # so show the class counts instead.
        sns.countplot(data=df_diabetes, x='Outcome', ax=axes[i][j])
    else:
        sns.boxplot(data=df_diabetes, y=c, x='Outcome', ax=axes[i][j])
    axes[i][j].set_title(c)
    j += 1
    if j % 3 == 0:
        i += 1
        j = 0

Drop 0 values¶

In total we delete 44 records (768 − 724) whose Glucose, BMI, or BloodPressure was recorded as an impossible 0. (The printed shapes below both read (724, 9) because this cell was re-executed after the rows had already been dropped.)

In [152]:
#Here we drop the 0 values that we have just observed in the dataset
print(df_diabetes.shape)
df_diabetes.drop(df_diabetes[(df_diabetes.Glucose <= 0) | (df_diabetes.BMI <= 0) | (df_diabetes.BloodPressure <= 0)].index, inplace=True)
df_diabetes.reset_index(inplace=True, drop=True)
print(df_diabetes.shape)
(724, 9)
(724, 9)
In [153]:
df_diabetes.head()
Out[153]:
Pregnancies Glucose BloodPressure SkinThickness Insulin BMI DiabetesPedigreeFunction Age Outcome
0 6 148 72 35 0 33.6 0.627 50 1
1 1 85 66 29 0 26.6 0.351 31 0
2 8 183 64 0 0 23.3 0.672 32 1
3 1 89 66 23 94 28.1 0.167 21 0
4 0 137 40 35 168 43.1 2.288 33 1

Interpretation of data vis¶

From the boxplots we can observe that some values are 0 that should not be 0: there are 0 values for the Glucose, BMI and BloodPressure features, which is physically impossible. Furthermore, people with a high risk of diabetes generally have higher age, pregnancies, BMI, insulin, glucose and blood pressure values. It is also an imbalanced dataset: there are considerably fewer people at risk of diabetes (roughly 270) than people not at risk (around 500).
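A quick way to quantify this imbalance is to count the labels directly. The sketch below is a minimal, self-contained illustration using a hypothetical label list (the 475/249 split is invented for illustration); in the notebook itself you would simply run `df_diabetes['Outcome'].value_counts(normalize=True)`.

```python
from collections import Counter

# Hypothetical stand-in for df_diabetes['Outcome'] (counts are illustrative,
# not taken from the real dataset).
labels = [0] * 475 + [1] * 249

counts = Counter(labels)
total = sum(counts.values())
for cls in sorted(counts):
    print(f"class {cls}: {counts[cls]} patients ({counts[cls] / total:.1%})")
```

With roughly one positive for every two negatives, accuracy alone can be misleading, which is worth keeping in mind when judging the model later.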

In [132]:
#We can use a heatmap to determine the features that have the highest correlation with Outcome.
# heatmap of correlations
# Create plot
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
fig.set_tight_layout(True)

# Compute correlation matrix
corr = df_diabetes.corr()

# Create upper triangular matrix to mask the upper triangular part of the heatmap
corr_mask = np.triu(np.ones_like(corr, dtype=bool))

# Generate a custom diverging colormap (because it looks better)
corr_cmap = sns.diverging_palette(230, 20, as_cmap=True)

sns.heatmap(corr, mask = corr_mask, cmap=corr_cmap, annot=True,square = True, linewidths=.5, ax = axes[0][0])
axes[0][0].set_title("Correlation plot of dataset features")


sns.kdeplot(data = df_diabetes, x='BMI', y='BloodPressure', ax=axes[0][1], fill=True)
axes[0][1].set_title("Density plot of BMI against blood pressure")



df_low = df_diabetes[df_diabetes.Glucose <= df_diabetes.Glucose.mean()]
df_high = df_diabetes[df_diabetes.Glucose > df_diabetes.Glucose.mean()]
sns.countplot(data=df_low, x='Outcome', ax=axes[1][0])
sns.countplot(data=df_high, x='Outcome', ax=axes[1][1])
axes[1][0].set_title("Outcome of people with <= average glucose levels")
axes[1][1].set_title("Outcome of people with > average glucose levels")

plt.show()

Interpretation of results¶

From the plots we can see that Glucose and BMI are the features most strongly correlated with a high risk of diabetes. Furthermore, a higher BMI generally correlates with a higher blood pressure, and people with lower-than-average glucose have a much lower chance of diabetes than those with high glucose levels, which makes sense.

Training The model¶

We will be using gradient boosted decision trees, the same gradient boosting machine learning method as in Lab 5; we will use the implementation from the xgboost package.

Note: the dataset we selected for this assignment has no categorical features. This means you will not need to encode categorical features nor use any encoders, which makes this assignment simpler than the Lab. It also allows you to generate explanations from both LIME and SHAP using the same model. You do not need to retrain the model we provided.
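To make the "gradient boosted trees" idea concrete, here is a toy, pure-Python sketch (not the xgboost algorithm itself, which adds regularization, real multi-feature trees, and more): each round fits a one-split decision stump to the residuals of the ensemble so far and adds it with a small learning rate.

```python
def fit_stump(xs, residuals):
    """Find the 1-D split minimising squared error; return (threshold, left_mean, right_mean)."""
    best = None
    for t in sorted(set(xs)):
        left = [r for x, r in zip(xs, residuals) if x <= t]
        right = [r for x, r in zip(xs, residuals) if x > t]
        if not left or not right:
            continue
        lm, rm = sum(left) / len(left), sum(right) / len(right)
        err = sum((r - lm) ** 2 for r in left) + sum((r - rm) ** 2 for r in right)
        if best is None or err < best[0]:
            best = (err, t, lm, rm)
    return best[1:]

def boost(xs, ys, n_rounds=50, lr=0.1):
    """Gradient boosting for squared loss: each stump fits the current residuals."""
    pred = [sum(ys) / len(ys)] * len(ys)   # start from the mean prediction
    for _ in range(n_rounds):
        residuals = [y - p for y, p in zip(ys, pred)]
        t, lm, rm = fit_stump(xs, residuals)
        pred = [p + lr * (lm if x <= t else rm) for x, p in zip(xs, pred)]
    return pred

# Tiny invented 1-D regression problem, for illustration only.
xs = [1, 2, 3, 4, 5, 6]
ys = [1.0, 1.1, 0.9, 3.0, 3.2, 2.9]
preds = boost(xs, ys)
print([round(p, 2) for p in preds])
```

After 50 rounds the ensemble's predictions sit close to the targets, which is the same additive mechanism (at much larger scale) behind the `XGBClassifier` used below.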

In [154]:
#Get Outcome labels
labels = df_diabetes['Outcome']

#Get features
data = df_diabetes.drop('Outcome', axis=1)



# create a train/test split
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(data, labels, train_size=0.70, random_state=7)
print("Train shape: ", X_train.shape)
print("Test shape: ", X_test.shape)
Train shape:  (506, 8)
Test shape:  (218, 8)
In [155]:
# Fit the model
gbtree_model = xgboost.XGBClassifier(learning_rate=0.01)
gbtree_model.fit(X_train, y_train)
Out[155]:
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
              colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
              early_stopping_rounds=None, enable_categorical=False,
              eval_metric=None, gamma=0, gpu_id=-1, grow_policy='depthwise',
              importance_type=None, interaction_constraints='',
              learning_rate=0.01, max_bin=256, max_cat_to_onehot=4,
              max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
              missing=nan, monotone_constraints='()', n_estimators=100,
              n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0,
              reg_alpha=0, reg_lambda=1, ...)
In [156]:
# Make predictions on test data
y_pred = gbtree_model.predict(X_test)

# Calculate accuracy against the real test outcomes
accuracy = sklearn.metrics.accuracy_score(y_test, y_pred)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
Accuracy: 77.06%

Part I: Intrinsic Explainability (20 points)¶

To Code (5 points)
A. Create 3 different feature importance plots using the intrinsic explainability of your model
Hint: Take a look at xgboost [documentation](https://xgboost.readthedocs.io/en/latest/python/python_api.html#).

In [163]:
# Feature importance plot 1
gbtree_model.get_booster().feature_names = list(df_diabetes.columns)[:-1]
plt.rcParams["figure.figsize"] = (8,8)
xgboost.plot_importance(gbtree_model, importance_type="weight")
plt.title("Feature Importance using weight as measurement criteria")
plt.show()
In [103]:
# Feature importance plot 2
xgboost.plot_importance(gbtree_model, importance_type="cover")
plt.title("Feature Importance using cover as measurement criteria")
plt.show()
In [104]:
# Feature importance plot 3
xgboost.plot_importance(gbtree_model, importance_type="gain")
plt.title("Feature Importance using gain as measurement criteria")
plt.show()

To Answer (15 points)
B. What kind of explanations do these plots provide? Think about three or four criteria.
C. Who can use these explanations in our context (which stakeholder(s)) and why (for what purpose)?
D. Observe the three plots and evaluate the overall consistency of these intrinsic explanations (are they consistent? what is missing?)

B. Feature importance can be defined in three different ways: using weight, gain, or cover.

  • Weight corresponds to the number of times a feature appears in a tree. A higher weight might indicate a higher feature importance, since the feature is used more often to make decisions in the tree.
  • Gain corresponds to the average gain of the splits that use the feature. A higher gain means more information gain per split (a greater reduction of entropy), which is a very good indicator of feature importance.
  • Cover is the average coverage of the splits that use the feature, where coverage is defined as the number of samples affected by the split.

C. Medical experts can use these feature importances for more insight in why the model would predict this patient as high-risk / low risk. For example, the model can predict that a patient with a high glucose level has high risk of diabetes. Then by looking at the feature importances, the medical expert can inform and explain why the patient has a high risk of diabetes.

D. The cover and gain feature importances agree the most, ranking Glucose, Age and BMI in the same order. Glucose is by far the most important feature for both, which is also confirmed by the correlation plot in the data exploration part above. Interestingly, the gain measure matches the order of importance given by the correlations best. The weight measure is quite different, ranking BMI and DiabetesPedigreeFunction as the two most important features, even though they rank lower on the other two measures.

Unfortunately, these feature importances also have limitations. They do not provide insight into the direction of the impact (the +/- sign), and they only provide relative importances (not absolute ones). They are also not very stable: retraining can yield different feature importances.
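One way to make the consistency check in D less eyeball-based is to pull the raw scores and compare the rankings programmatically. The sketch below uses hypothetical score dictionaries (the numbers are invented for illustration); in the notebook the real ones come from `gbtree_model.get_booster().get_score(importance_type=...)` for each criterion.

```python
# Hypothetical importance scores (invented for illustration); in the notebook,
# obtain the real ones via gbtree_model.get_booster().get_score(importance_type=...).
weight = {"BMI": 90, "DiabetesPedigreeFunction": 85, "Glucose": 70, "Age": 60}
gain   = {"Glucose": 2.1, "Age": 0.9, "BMI": 0.8, "DiabetesPedigreeFunction": 0.5}
cover  = {"Glucose": 150, "Age": 120, "BMI": 110, "DiabetesPedigreeFunction": 80}

def ranking(scores):
    """Features sorted from most to least important."""
    return [f for f, _ in sorted(scores.items(), key=lambda kv: -kv[1])]

ranks = {name: ranking(s) for name, s in
         [("weight", weight), ("gain", gain), ("cover", cover)]}
for name, order in ranks.items():
    print(name, order)

# Simple consistency check: do the three criteria agree on the top feature?
tops = {order[0] for order in ranks.values()}
print("criteria agree on top feature:", len(tops) == 1)
```

In this invented example, weight disagrees with gain and cover about the top feature, mirroring the disagreement observed in the real plots above.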


Part II: LIME: Local Interpretable Model-agnostic Explanations (40 points)¶

LIME Documentation

To Code
A. Implement a LIME explainer (10 points)
B. Use it to generate explanations (with visualization) on 4 data points (10 points)
Note: remember, our dataset does not have categorical features, so you do not need to specify any when creating the LIME explainer. Feature names and class names are, however, needed. Take a closer look at the documentation (link above)

In [186]:
explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values, feature_names=list(df_diabetes.columns)[:-1], class_names=['Low Risk','High Risk'])
In [183]:
#Here we have the first 5 test points with corresponding label
test_points = df_diabetes.iloc[list(X_test.index)]
test_points.head()
Out[183]:
Pregnancies Glucose BloodPressure SkinThickness Insulin BMI DiabetesPedigreeFunction Age Outcome
321 1 130 70 13 105 25.9 0.472 22 0
618 11 127 106 0 0 39.0 0.190 51 0
376 4 95 64 0 0 32.0 0.161 31 1
94 0 125 96 0 0 22.5 0.262 21 0
663 2 127 46 21 335 34.4 0.176 22 0
Visualization 0¶
In [187]:
exp = explainer.explain_instance(X_test.values[0], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
Visualization 1¶
In [188]:
exp = explainer.explain_instance(X_test.values[1], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
Visualization 2¶
In [189]:
exp = explainer.explain_instance(X_test.values[2], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)
Visualization 3¶
In [190]:
exp = explainer.explain_instance(X_test.values[7], gbtree_model.predict_proba, num_features=5)
exp.show_in_notebook(show_table=True, show_all=True)

To Answer
C. What kind of explanations do these plots provide? Think about three or four criteria. (5 points)
D. Briefly describe how we can read each of the 4 explanations for the points you selected (10 points)
E. Who can use these explanations in our context (which stakeholder(s)) and why (for what purpose)? (5 points)

C. LIME stands for Local Interpretable Model-agnostic Explanations. The local aspect means it is used to explain individual predictions of a machine learning model. It shows which medical features contribute to predicting that the patient is at high or low risk, and how much each feature contributes to the decision.

D. Let's walk through all our 4 datapoints.

Vis 0: The true outcome is Low-Risk. On the left we can see that the model predicts with 0.78 probability that the patient has a low risk of diabetes. In the middle of the plot we can observe that this is mainly due to the patient's BMI <= 27.73 and Age <= 24.00. This seems to be a young patient with a healthy weight, who is therefore classified as low risk.

Vis 1: The true outcome of this person is Low-Risk. On the left we can see that the model is quite unsure and predicts with 0.57 probability that the patient has a low risk of diabetes. In the middle of the plot we can observe the patient's BMI > 36.80 and Age > 41.00. The higher age and high BMI contribute to classifying this patient as high-risk, but the blood pressure and glucose levels seem normal, so eventually the model classifies the patient as low-risk, which is indeed correct.

Vis 2: The true outcome of this person is High-Risk. On the left we can see that the model is quite unsure and predicts with 0.67 probability that the patient has a low risk of diabetes. The person is somewhat overweight, an indicator of diabetes, but has a relatively low glucose level. In the middle of the plot we can observe that the model attributes Glucose <= 99.0 to being low-risk. Because of this dominant feature, the model eventually predicts low-risk, which is wrong, since the patient is actually high-risk.

Vis 3: The true outcome of this person is High-Risk. On the left we can see that the model is quite sure and predicts with 0.72 probability that the patient has a high risk of diabetes. The person is heavily overweight and has low glucose levels and blood pressure. In the middle of the plot we can observe that the model indeed uses Glucose and BMI as the features most important for predicting high risk. Even though this is a young person, the model correctly identified them as being at high risk of developing diabetes.

E. Using LIME has the advantage that we can analyse each person's medical data locally and see which features were important in the model's classification. Medical experts and doctors can use this to discuss the results with the patient, pinpoint exactly which medical feature is the reason for the high/low diabetes risk, and thus motivate why this person would or would not be selected for the clinical drug trial.
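For intuition on what the lime library is doing under the hood, here is a minimal, single-feature sketch of the LIME recipe: perturb the instance, query the black box, weight samples by proximity, and fit a weighted linear surrogate whose slope is the local "feature weight". Everything here is a hypothetical stand-in, not the library's actual implementation: `black_box` plays the role of `gbtree_model.predict_proba`, and the kernel width and numbers are chosen arbitrarily.

```python
import math, random

random.seed(0)

def black_box(x):
    # Hypothetical stand-in for the classifier: risk rises smoothly with glucose.
    return 1 / (1 + math.exp(-(x - 100) / 10))

x0 = 120.0                                            # instance to explain
xs = [x0 + random.gauss(0, 15) for _ in range(500)]   # perturbed neighbours
ys = [black_box(x) for x in xs]                       # black-box predictions
ws = [math.exp(-((x - x0) ** 2) / (2 * 10 ** 2)) for x in xs]  # proximity kernel

# Closed-form weighted least squares for the surrogate y ~ a + b*x
sw = sum(ws)
mx = sum(w * x for w, x in zip(ws, xs)) / sw
my = sum(w * y for w, y in zip(ws, ys)) / sw
b = sum(w * (x - mx) * (y - my) for w, x, y in zip(ws, xs, ys)) / \
    sum(w * (x - mx) ** 2 for w, x in zip(ws, xs))
a = my - b * mx
print("local slope (feature weight):", round(b, 4))
```

The positive slope says that, near this particular patient, higher glucose pushes the prediction towards high risk; this is the kind of signed, local weight the LIME bar charts above display per feature.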


Part III: SHAP: SHapley Additive exPlanations (40 points)¶

SHAP Documentation

To Code
A. Implement a SHAP explainer (5 points)
B. Use it to generate explanations: (15 points)

  • B1. Global explanations.
  • B2. Local explanations for the same 4 data points you selected for LIME.
  • B3. Cohort explanations for persons who had more than 2 pregnancies.
In [195]:
shap.initjs()

#Create explainer and run on test data
explainer = shap.TreeExplainer(gbtree_model)
shap_values = explainer.shap_values(X_test)
ntree_limit is deprecated, use `iteration_range` or model slicing instead.

B1 Global explanations¶

In [197]:
figure = plt.figure()
shap.summary_plot(shap_values, X_test, plot_type='bar')

Interpretation of plot above¶

This plot shows the global ranking of features from most to least significant. Glucose has the most predictive power according to the model, and Age is the second most predictive feature.

In [199]:
shap.summary_plot(shap_values, X_test)

Interpretation of plot above¶

The plot above shows the global list of most important features from top to bottom. Each dot represents the feature value for a single data instance: a blue dot indicates a low feature value and a red dot a high one. From this we can observe that a high Glucose value pushes the prediction towards a positive outcome (having diabetes). The same generally holds for Age, BMI, and DiabetesPedigreeFunction.
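The ordering in both summary plots comes from a simple statistic: the mean absolute SHAP value per feature. Here is a self-contained sketch with a hypothetical 4-patient SHAP matrix (invented numbers); in the notebook the same computation would apply to `shap_values`.

```python
# Hypothetical SHAP matrix for illustration: rows = patients, columns = features.
features = ["Glucose", "Age", "BMI"]
shap_matrix = [
    [ 1.2, -0.3,  0.4],
    [-0.9,  0.5, -0.1],
    [ 1.5,  0.2,  0.3],
    [-1.1, -0.4,  0.2],
]

# Mean absolute contribution per feature: this is what the bar plot ranks.
mean_abs = {f: sum(abs(row[j]) for row in shap_matrix) / len(shap_matrix)
            for j, f in enumerate(features)}
for f, v in sorted(mean_abs.items(), key=lambda kv: -kv[1]):
    print(f, round(v, 3))
```

Taking the absolute value before averaging matters: Glucose pushes some patients up and others down, but it still has the largest overall influence.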

B2 - Local explanations (we use the same data points as before)¶

How to read the local plots¶

The explanations below show how each feature contributes to pushing the model output from the base value (the average model output over the training dataset we passed) to the final model output. Features pushing the prediction higher are shown in red; those pushing the prediction lower are in blue.

Visualisation 0¶

In [209]:
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[0], X_test.iloc[0,:])
Interpretation of results.¶

The true outcome is Low-Risk. Because of this patient's low BMI and young age, the prediction is pushed to the left, and so the patient is predicted as low risk.

Visualisation 1¶

In [211]:
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[1], X_test.iloc[1,:])
Interpretation of results.¶

The true outcome of this person is Low-Risk. The model is a bit unsure here: the high BMI and Age push the prediction to the right (high-risk), but other factors push it back to the left. Eventually the model outputs -0.298, which is still below 0, so the patient is classified as low risk.
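A quick sanity check on that reading: TreeExplainer for an XGBClassifier works on the log-odds (margin) scale, so the waterfall's final value of -0.298 is not a probability. Passing it through the sigmoid recovers the predicted probability of the High Risk class:

```python
import math

margin = -0.298                          # final model output from the waterfall above
p_high = 1 / (1 + math.exp(-margin))     # sigmoid: log-odds -> probability
print(f"P(High Risk) = {p_high:.3f}")    # below 0.5, consistent with a Low Risk call
```

This also explains why "below 0" on the waterfall axis corresponds to "below 0.5 probability" in the LIME view of the same patient.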

Visualisation 2¶

In [212]:
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[2], X_test.iloc[2,:])
Interpretation of results.¶

The true outcome of this person is High-Risk. The model is a bit unsure here: mainly the low glucose pushes the prediction far to the left, while the patient being somewhat older and overweight pushes it to the right.

Visualisation 3¶

In [213]:
_waterfall.waterfall_legacy(explainer.expected_value, shap_values[7], X_test.iloc[7,:])
Interpretation of results.¶

The true outcome of this person is High-Risk. The model is quite certain here, mainly because of the elevated glucose level, which pushes the prediction far to the right and ends in a correct high-risk prediction.

B3. Cohort explanations for persons who had more than 2 pregnancies.¶

In [236]:
#Get test points for persons with > 2 pregnancies
test_points = df_diabetes.iloc[list(X_test.index)]

pregnant_df = test_points[test_points.Pregnancies > 2]
pregnant_y = pregnant_df['Outcome']
pregnant_x = pregnant_df.drop('Outcome', axis=1)
In [237]:
print(test_points.shape)
print(pregnant_x.shape)
(218, 9)
(116, 8)
In [238]:
shap_p = explainer.shap_values(pregnant_x)
shap.force_plot(explainer.expected_value, shap_p, pregnant_x, show=True)
ntree_limit is deprecated, use `iteration_range` or model slicing instead.
Out[238]:
Visualization omitted, Javascript library not loaded!
Interpretation of Results¶

This force plot combines the explanations of all test patients who had more than two pregnancies. It shows the same information as the previous waterfall plots, but in an interactive JavaScript environment where we can inspect which features contributed most to each classification.

To Answer
C. Briefly describe what information we can get from each of the plots (B1, B2, and B3) (10 points)
D. Who can use each type of explanation, and for what purposes? (10 points)

C. The descriptions of the plots are located below each plot.

D. It is very important for physicians to check that the model predicts low-risk patients correctly, since those patients become eligible for the drug trial. If a person is actually at risk of diabetes while the model predicts low risk, this can cause major issues, since the drug can have adverse effects on people with diabetes. So we are especially interested in false negatives (the recall metric). By looking at a specific person's explanation, the physician can discuss the results with the patient, pinpoint exactly which medical feature drives the high/low risk prediction, and motivate why this person would or would not be selected for the clinical drug trial. The development team and hospital management are more likely to be interested in which features globally have the most impact on the classification outcome; for this, the plots in B1 are relevant.
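The false-negative concern above can be made concrete with a small, hypothetical example (invented labels and predictions, 1 = High Risk): recall measures exactly the fraction of truly high-risk patients the model fails to clear for the trial.

```python
# Hypothetical test labels and model predictions, for illustration only.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 1, 0]
y_pred = [1, 0, 1, 1, 0, 0, 1, 0, 0, 0]

# A false negative is a truly high-risk patient predicted as low risk,
# i.e. someone who would wrongly be admitted to the drug trial.
tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
recall = tp / (tp + fn)
print("false negatives:", fn, "recall:", recall)
```

In the notebook, the same numbers come from `sklearn.metrics.recall_score(y_test, y_pred)`, and SHAP or LIME explanations of the false-negative cases would show physicians which features misled the model.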